%% Calculate P_hat

function [P_hat_mnj,b_hat,b_hat_f,P_M_hat,P_M_f_hat] = PriceLoopCES_fun_approx(P_hat_mnj0,w_hat,pi_mnj,...
       pi_M, pi_l,pi_mnj_sorted,pi_M_f,pi_l_f)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Author: Produced for JdG,AL,IM by Christopher Evans at UPF
%
% Solves for the relative change in prices. Given a wage change w_hat and
% an initial price guess, the function iterates
%
%     P_hat  ->  b_hat(P_hat, w_hat)  ->  updated P_hat
%
% until the distance between successive guesses is at most the global tol,
% or maxit iterations have been run. It is meant to be called repeatedly
% from an outer loop.
%
% Notation follows master_notes3.pdf and the draft: subscripts {mn,ij}
% denote the country pair {mn} and the sector pair {ij}. Labour and
% intermediate shares are passed in as arguments (pi_l, pi_M, ...) rather
% than read from the data (alpha, gamma). This lets the same function
% serve both the baseline setting and the heterogeneous calibration
% (initial wedge exercise). Intermediate shares come as one large matrix
% saved in Step2 instead of per-country .(ISO{}) fields, which avoids
% loops.
%
% "approx" version: the intermediate price index and the input-bundle
% cost are computed in logs as share-weighted geometric means,
%     log P_M_hat = sum over rows of pi_M .* log P_hat
%     log b_hat   = pi_l .* log w_hat + (1 - pi_l) .* log P_M_hat,
% while the bilateral price change keeps the CES form
%     P_hat = (pi .* Xi_hat .* (b_hat .* a_hat).^(1-rho)).^(1/(1-rho)).
% See master_notes3 for the derivation.
%
% The French rows of P_hat are rebuilt from firm-level data. Each firm's
% cost change uses firm-specific shares (pi_l_f, pi_M_f) and firm shocks
% (a_hat_f, Xi_hat_mnjf). Firm terms are aggregated to sectors through
% firmsecD_sorted.
%
% Inputs (shapes inferred from the operations below -- confirm with caller):
%   P_hat_mnj0    (N*J) x N initial guess. Rows are exporter-sector blocks
%                 (country m occupies rows start(m):start(m+1)-1); columns
%                 are destinations n.
%   w_hat         wage changes, presumably 1 x N.
%   pi_mnj        (N*J) x N bilateral shares.
%   pi_M          intermediate shares, conformable with
%                 kron(P_hat, ones(1,J)), i.e. (N*J) x (N*J).
%   pi_l          labour shares, presumably (N*J) x 1.
%   pi_mnj_sorted FI_J x N French firm-level bilateral shares.
%   pi_M_f        FI_J x (N*J) French firm intermediate shares.
%   pi_l_f        FI_J x 1 French firm labour shares.
%
% Outputs (all evaluated at the final price guess):
%   P_hat_mnj  (N*J) x N price changes.
%   b_hat      (N*J) x 1 input-bundle cost changes.
%   b_hat_f    FI_J x 1 French firm cost changes.
%   P_M_hat    1 x (N*J) intermediate price index changes (Inf set to 1).
%   P_M_f_hat  FI_J x 1 French firm intermediate price index changes.
%
% Warning 'PriceLoopCES_fun_approx:noConvergence' is issued when the loop
% stops without the distance reaching tol. This includes a NaN distance:
% NaN > tol is false, which would otherwise end the loop silently.
%
% Last Updated: 5/02/2019 by IM & JDG
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

global alphaT rhoT lambdaT etaT rho sigma sigmaT lambda eta varphi ...
    tol maxit Pi_hat_n D_hat_n Xi_hat_mnjf Xi_hat_mnj a_hat_f a_hat ...
    N J FI_J france countrysecD firmsecD_sorted start start_sorted ...
    sum_by_country_dummy ISO ... 
    sL_n sPi_n sD_n alpha_WIOD_j alpha_WIOD_francej labour_share_PWT

% Loop-invariant terms. The globals are not modified inside this function,
% so these are computed once rather than on every iteration.
a_hat_repmat   = repmat(a_hat,[1 N]);    % technology shocks, one column per destination
rho_mnj_repmat = repmat(rho,[N N]);      % sector elasticities laid out like pi_mnj
rho_f          = firmsecD_sorted*rho;    % elasticity of each French firm's sector
france_rows    = start(france):start(france+1)-1;

pfmax = 1;
it    = 1;

% P_hat_knj is the current guess; P_hat_mnj is the update computed from it.
P_hat_knj = P_hat_mnj0;

fprintf('   PRICE ITERATION=%d \n',it);

while (it <= maxit) && (pfmax > tol)
   %% Sector-level prices
   b_hat = sector_cost_hat(P_hat_knj, w_hat, pi_M, pi_l, countrysecD, J);

   % The cost change does not depend on the destination n, so it is
   % replicated across the N columns.
   b_hat_rep = repmat(b_hat,[1 N]);

   P_hat_mnj = (pi_mnj.*Xi_hat_mnj.*(b_hat_rep.*a_hat_repmat)...
      .^(1-rho_mnj_repmat)).^(1./(1-rho_mnj_repmat));

   %% French firm-level prices
   % Firm-specific labour/import shares and shocks enter the cost
   % function, so France is rebuilt from firm data and overwrites the
   % sector-level rows computed above.
   b_hat_f = firm_cost_hat(P_hat_knj, w_hat, pi_M_f, pi_l_f, france, FI_J);

   inside_brackets = (b_hat_f.*a_hat_f).^(1-rho_f);
   pi_xi_brackets  = pi_mnj_sorted.*Xi_hat_mnjf.*repmat(inside_brackets,[1 N]);

   % Aggregate firms to their sector (J x N), then raise each sector row j
   % to the power 1/(1-rho(j)).
   sum_pi_xi_brackets = (firmsecD_sorted')*pi_xi_brackets;
   P_hat_mnj_france   = bsxfun(@power, sum_pi_xi_brackets, 1./(1-rho));

   P_hat_mnj(france_rows,:) = P_hat_mnj_france;

   %% Convergence check
   % Should not be needed: non-finite prices indicate a problem in the
   % inputs prepared by the preceding files.
   P_hat_mnj(~isfinite(P_hat_mnj)) = 1;

   % Relative deviation of the update from the current guess. pfmax is
   % the Euclidean norm of these deviations (not their maximum), although
   % the log line below labels it "Max price distance".
   pfdev = abs(P_hat_knj(:)-P_hat_mnj(:))./P_hat_knj(:);
   pfmax = norm(pfdev);

   % Update the guess.
   P_hat_knj = P_hat_mnj;
   it        = it + 1;

   fprintf('   PRICE ITERATION=%d        Max price distance=%d  \n',it-1,pfmax);
end

% Written as ~(pfmax <= tol) so that a NaN distance is also reported.
if ~(pfmax <= tol)
   warning('PriceLoopCES_fun_approx:noConvergence', ...
      'Price loop stopped after %d iteration(s) without converging (distance = %g, tol = %g).', ...
      it-1, pfmax, tol);
end

%% Final outputs, evaluated at the last price update
P_hat_mnj = P_hat_knj;

[b_hat, P_M_hat] = sector_cost_hat(P_hat_knj, w_hat, pi_M, pi_l, countrysecD, J);
% Guard applied to the reported index only; b_hat is built from the log.
P_M_hat(isinf(P_M_hat)) = 1;

[b_hat_f, P_M_f_hat] = firm_cost_hat(P_hat_knj, w_hat, pi_M_f, pi_l_f, france, FI_J);
end


function [b_hat, P_M_hat, log_P_M_hat] = sector_cost_hat(P_hat, w_hat, pi_M, pi_l, countrysecD, J)
% SECTOR_COST_HAT Input-bundle cost change for every country-sector row.
%   log P_M_hat = sum over rows of pi_M .* kron(log P_hat, ones(1,J))
%   log b_hat   = pi_l .* (countrysecD * log w_hat') + (1 - pi_l) .* log P_M_hat
% Returns b_hat as a column and P_M_hat / log_P_M_hat as rows.

% kron repeats column n of log(P_hat) J times, so column (n-1)*J+j holds
% the log prices faced by destination n for each buying sector j.
log_P_M_hat = squeeze(sum(pi_M.*kron(log(P_hat),ones(1,J)),1));
P_M_hat     = exp(log_P_M_hat);

% countrysecD presumably maps each country-sector row to its country's
% wage -- confirm against the code that builds it.
log_w_term  = pi_l.*(countrysecD*log(w_hat)');
log_PM_term = (log_P_M_hat.*(1-pi_l)')';

b_hat = exp(log_w_term + log_PM_term);
end


function [b_hat_f, P_M_f_hat] = firm_cost_hat(P_hat, w_hat, pi_M_f, pi_l_f, france, FI_J)
% FIRM_COST_HAT Cost change for each French firm (FI_J x 1).
%   Uses column `france` of P_hat (prices of goods bought by France) and
%   the firm-specific shares pi_M_f and pi_l_f.
log_P_france  = log(P_hat(:,france)');                              % 1 x (N*J)
log_P_M_f_hat = sum(pi_M_f.*kron(log_P_france,ones(FI_J,1)),2);     % FI_J x 1
P_M_f_hat     = exp(log_P_M_f_hat);

log_w_term  = pi_l_f.*repmat(log(w_hat(france)),[FI_J 1]);
log_PM_term = log_P_M_f_hat.*(1-pi_l_f);

b_hat_f = exp(log_w_term + log_PM_term);
end